Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading #21408
Conversation
Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading

Fixes Lightning-AI#21255

This commit adds the adapt_checkpoint_hparams() public method to LightningCLI, allowing users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This is particularly useful when using checkpoints from a TrainingModule with a different InferenceModule class that has different __init__ parameters.

Problem: When loading a checkpoint trained with TrainingModule(lr=1e-3) into an InferenceModule() that doesn't accept 'lr' as a parameter, the CLI would fail during instantiation because it tries to pass all checkpoint hyperparameters to the new module class.

Solution: Added an adapt_checkpoint_hparams() hook that is called in _parse_ckpt_path() after loading checkpoint hyperparameters but before applying them. Users can override this method to:
- Remove training-specific hyperparameters (e.g., lr, weight_decay)
- Modify _class_path for subclass mode
- Transform hyperparameter names/values
- Disable checkpoint hyperparameters entirely by returning {}

Example usage:

    class MyCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            checkpoint_hparams.pop('lr', None)
            checkpoint_hparams.pop('weight_decay', None)
            return checkpoint_hparams

This approach is preferable to:
- Disabling checkpoint loading entirely (loses valuable hyperparameter info)
- Adding CLI arguments (deviates from the Trainer parameter pattern)
- Modifying private methods (breaks encapsulation)

The hook provides maximum flexibility while maintaining backward compatibility (the default implementation returns the hyperparameters unchanged).
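To make the scenario concrete, here is a minimal sketch of the two module classes and the CLI subclass described above. The layer sizes and defaults are illustrative, and the hook signature shown is the one from this commit (a subcommand parameter is added later in the review).

```python
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI


class TrainingModule(LightningModule):
    """Saves lr/weight_decay into the checkpoint via save_hyperparameters()."""

    def __init__(self, lr: float = 1e-3, weight_decay: float = 0.0):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)


class InferenceModule(LightningModule):
    """Accepts neither lr nor weight_decay, so raw checkpoint hparams would break instantiation."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)


class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, checkpoint_hparams):
        # Drop training-only keys before they are passed to InferenceModule.__init__
        checkpoint_hparams.pop("lr", None)
        checkpoint_hparams.pop("weight_decay", None)
        return checkpoint_hparams
```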
for more information, see https://pre-commit.ci
Pull request overview
This PR adds a public adapt_checkpoint_hparams() hook to LightningCLI that enables users to customize hyperparameters loaded from checkpoints before model instantiation. This addresses the issue of loading checkpoints across different module classes (e.g., from TrainingModule to InferenceModule) where incompatible __init__ parameters would otherwise cause failures.
Key Changes:
- Added adapt_checkpoint_hparams() public method with comprehensive documentation
- Integrated the hook into _parse_ckpt_path() to allow customization before hyperparameter application
- Maintained backward compatibility with a default no-op implementation
src/lightning/pytorch/cli.py
Outdated
    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt checkpoint hyperparameters before instantiating the model class.

        This method allows for customization of hyperparameters loaded from a checkpoint when
        using a different model class than the one used for training. For example, when loading
        a checkpoint from a TrainingModule to use with an InferenceModule that has different
        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.

        Args:
            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

        Returns:
            Dictionary of adapted hyperparameters to be used for model instantiation.

        Example::

            class MyCLI(LightningCLI):
                def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
                    # Remove training-specific hyperparameters not needed for inference
                    checkpoint_hparams.pop("lr", None)
                    checkpoint_hparams.pop("weight_decay", None)
                    return checkpoint_hparams

        Note:
            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
            hyperparameters, you may need to modify it as well to point to your new module class.
        """
        return checkpoint_hparams
Copilot AI (Dec 6, 2025)
The new adapt_checkpoint_hparams() hook lacks test coverage. Given that tests/tests_pytorch/test_cli.py contains comprehensive tests for checkpoint loading functionality (e.g., test_lightning_cli_ckpt_path_argument_hparams and test_lightning_cli_ckpt_path_argument_hparams_subclass_mode), tests should be added to verify:
- The hook is called when loading checkpoint hyperparameters
- Modifications made in the hook are applied correctly
- Returning an empty dict properly skips checkpoint hyperparameter loading
- The hook works in both regular and subclass modes
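Such a test could look roughly like the sketch below. The cleandir fixture and the checkpoint-glob pattern follow the existing CLI tests, while HParamOnlyModel, the test name, and the exact assertions are placeholders rather than the tests eventually added to this PR; the single-argument hook signature matches the code at this point in the review.

```python
from pathlib import Path
from unittest import mock

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel


class HParamOnlyModel(BoringModel):
    """Stores a hyperparameter that is not tied to any layer shape."""

    def __init__(self, out_dim: int = 2):
        super().__init__()
        self.save_hyperparameters()


def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):
    seen = {}

    class HookCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, checkpoint_hparams):
            seen.update(checkpoint_hparams)          # record that the hook ran
            checkpoint_hparams.pop("out_dim", None)  # modification that should be applied
            return checkpoint_hparams

    # Create a checkpoint first, then reload it through --ckpt_path so the hook runs.
    with mock.patch("sys.argv", ["any.py", "fit", "--model.out_dim=3", "--trainer.max_epochs=1"]):
        cli = HookCLI(HParamOnlyModel)

    ckpt = str(next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")))
    with mock.patch("sys.argv", ["any.py", "validate", f"--ckpt_path={ckpt}"]):
        cli = HookCLI(HParamOnlyModel)

    assert seen.get("out_dim") == 3        # the hook received the checkpoint hyperparameters
    assert cli.model.hparams.out_dim == 2  # the popped key fell back to the class default
```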
src/lightning/pytorch/cli.py
Outdated
        else:
            self.config = parser.parse_args(args)

    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
Copilot AI (Dec 6, 2025)
Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.
src/lightning/pytorch/cli.py
Outdated
        Example::

            class MyCLI(LightningCLI):
                def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
Copilot AI (Dec 6, 2025)
Use lowercase dict instead of Dict for type annotations to align with the modern Python 3.9+ style used throughout this file. Change Dict[str, Any] to dict[str, Any] in both the parameter and return type annotations.
mauvilsa
left a comment
It is looking good. However, the subcommand parameter is missing. Also please add unit tests.
src/lightning/pytorch/cli.py
Outdated
        else:
            self.config = parser.parse_args(args)

    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
Suggested change:

-    def adapt_checkpoint_hparams(self, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
+    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: Dict[str, Any]) -> Dict[str, Any]:
As mentioned in my proposal, the method should receive a subcommand parameter.
src/lightning/pytorch/cli.py
Outdated
| checkpoint_hparams.pop("lr", None) | ||
| checkpoint_hparams.pop("weight_decay", None) |
In this example, removing lr and weight_decay should not be done if the subcommand is fit.
src/lightning/pytorch/cli.py
Outdated
            return

        # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
        hparams = self.adapt_checkpoint_hparams(hparams)
Suggested change:

-        hparams = self.adapt_checkpoint_hparams(hparams)
+        hparams = self.adapt_checkpoint_hparams(subcommand, hparams)
…ook and add tests

- Update adapt_checkpoint_hparams signature to include a subcommand parameter, allowing context-aware customization of checkpoint hyperparameters
- Change type annotations to use lowercase dict (Python 3.9+ style)
- Update docstring with subcommand parameter documentation
- Add example showing conditional logic based on subcommand
- Add comprehensive unit tests:
  - test_adapt_checkpoint_hparams_hook: tests that the hook is called and modifications are applied
  - test_adapt_checkpoint_hparams_hook_empty_dict: tests disabling checkpoint hparams loading
  - Tests cover both regular and subclass modes
for more information, see https://pre-commit.ci
Thanks for the response. I already updated the hook signature as suggested. Also added:

    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        if subcommand != "fit":
            checkpoint_hparams.pop("lr", None)  # Remove training params for inference
        return checkpoint_hparams

I also included 2 comprehensive tests: test_adapt_checkpoint_hparams_hook and test_adapt_checkpoint_hparams_hook_empty_dict, covering both regular and subclass modes.
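For the second of those tests, the CLI subclass can be very small. A sketch assuming the updated two-argument signature; the AdaptHparamsEmptyCLI name appears later in this thread, and the body here is illustrative:

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


class AdaptHparamsEmptyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # Returning an empty dict disables checkpoint hyperparameters entirely
        return {}
```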
- Split method signature across multiple lines to stay within the 120 character limit
- Improves code readability in the documentation example
mauvilsa
left a comment
It is looking good. But the two tests fail. You will need to implement a new Model class for these tests.
tests/tests_pytorch/test_cli.py
Outdated
    assert cli.model.layer.out_features == 4


def test_adapt_checkpoint_hparams_hook(cleandir):
Suggested change:

-def test_adapt_checkpoint_hparams_hook(cleandir):
+def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):
tests/tests_pytorch/test_cli.py
Outdated
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
Suggested change (remove these lines):

-    def add_arguments_to_parser(self, parser):
-        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction.
tests/tests_pytorch/test_cli.py
Outdated
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
Suggested change (remove these lines):

-    def add_arguments_to_parser(self, parser):
-        parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction.
tests/tests_pytorch/test_cli.py
Outdated
    # First, create a checkpoint
    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
The test fails because BoringCkptPathModel has a module torch.nn.Linear(32, out_dim). If out_dim is changed, there is a tensor size mismatch.
Instead of using BoringCkptPathModel, implement a new class for these two tests that just sets an attribute which can be asserted after instantiation.
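A model along these lines could look like the sketch below. The class name is illustrative; hidden_dim defaulting to 16 follows the later discussion in this thread, and the out_dim default is an assumption.

```python
from lightning.pytorch.demos.boring_classes import BoringModel


class HParamsOnlyModel(BoringModel):
    """Only records its hyperparameters; no layer depends on them, so no tensor size mismatch."""

    def __init__(self, out_dim: int = 2, hidden_dim: int = 16):
        super().__init__()
        self.save_hyperparameters()
```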
… size mismatch in tests
for more information, see https://pre-commit.ci
@mauvilsa Thanks for the detailed feedback! I've implemented your suggestion. You correctly identified that the tests were failing due to tensor size mismatches. The original tests used BoringCkptPathModel; I created a new, simple model class specifically for testing the hook.
Tests Updated
mauvilsa
left a comment
@arrdel your tests still fail. It would be better if you run the tests locally and make sure everything works correctly before pushing. There are instructions on how to do that. More or less how I do it on Linux is (I don't remember exactly):
- Create a virtual env
- Install lightning like: export PACKAGE_NAME=lightning ; pip install -e ".[test]"
- Install jsonargparse like: pip install "jsonargparse[signatures]"
Then to run only the CLI tests, I do:
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
pytest -v tests/tests_pytorch/test_cli.py
            hyperparameters, you may need to modify it as well to point to your new module class.
        """
        return checkpoint_hparams
Not really related to this new feature, but there is also my comment in #21116 (comment). Nobody responded to it. Maybe by default fit should not use the hparams from the checkpoint?
Also, this could be related to #21255 (comment).
I am not really sure what to do here.
@arrdel any comment on this?
Actually, it seems #21455 would fix this comment, I think.
Removed redundant method implementations since BoringModel provides them.
@mauvilsa Thanks for the feedback! I've made the fix you suggested:
The test was asserting hidden_dim==3 but only passing out_dim=3. Since hidden_dim defaults to 16 and there's no argument linking, the assertion failed. Now we explicitly pass --model.hidden_dim=6.
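In sketch form, the corrected setup passes the value explicitly instead of relying on a default or a linked argument. The CLI and model class names below are placeholders and the numbers are illustrative, asserting against the explicitly passed value:

```python
from unittest import mock

cli_args = ["fit", "--model.out_dim=3", "--model.hidden_dim=6", "--trainer.max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
    cli = AdaptHparamsCLI(HParamsOnlyModel)

assert cli.model.hparams.hidden_dim == 6  # asserts against the explicitly passed value
```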
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

@@           Coverage Diff            @@
##           master   #21408    +/-   ##
=========================================
- Coverage      87%      79%      -8%
=========================================
  Files         269      267       -2
  Lines       23804    24009     +205
=========================================
- Hits        20626    18957    -1669
- Misses       3178     5052    +1874
mauvilsa
left a comment
Great to see that the tests are now successful. I still have these two comments, but overall it looks good, so I approve now. Anyway, my approval is not that useful, since someone from the Lightning team still needs to approve.
| checkpoint_hparams.pop("out_dim", None) | ||
| checkpoint_hparams.pop("hidden_dim", None) |
From a testing perspective there is no difference between out_dim and hidden_dim. It might be better if one is popped and the other not, so that both cases are tested?
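Implemented as suggested, the test's hook would pop only one of the two keys, roughly as in this sketch (the class name is a placeholder; the two-argument signature from earlier in the thread is assumed):

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


class PopOneKeyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # Pop only one key so the test exercises both the removed and the retained case
        checkpoint_hparams.pop("out_dim", None)
        return checkpoint_hparams
```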
            hyperparameters, you may need to modify it as well to point to your new module class.
        """
        return checkpoint_hparams
@arrdel any comment on this?
For some reason I am unable to resolve my old comments that have been addressed already.
What does this PR do?
Fixes #21255
This PR adds a public adapt_checkpoint_hparams() hook to LightningCLI that allows users to customize hyperparameters loaded from checkpoints before they are used to instantiate model classes. This solves the problem of loading checkpoints across different module classes (e.g., from TrainingModule to InferenceModule).

Problem
When using LightningCLI with checkpoints, hyperparameters saved during training are automatically loaded and applied when running other subcommands (test, predict, etc.). This is convenient when using the same module class, but fails when using a different class with incompatible __init__ parameters.

Example scenario:
Running cli predict --ckpt_path checkpoint.ckpt with InferenceModule fails because the CLI tries to pass lr=1e-3 from the checkpoint to InferenceModule.__init__().

Solution
Added an adapt_checkpoint_hparams() public method that users can override to customize loaded hyperparameters, as in the sketch below:
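A sketch of such an override, consistent with the examples discussed earlier in this thread (the class name and the specific keys removed are illustrative):

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        if subcommand != "fit":
            # Remove training-only parameters before they reach the inference module
            checkpoint_hparams.pop("lr", None)
            checkpoint_hparams.pop("weight_decay", None)
        return checkpoint_hparams
```

Implementation Details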
- Added the adapt_checkpoint_hparams() public method in LightningCLI
- Modified _parse_ckpt_path() to call the hook after loading but before applying hyperparameters

Why This Approach?
As discussed in #21255, this is superior to the alternatives (disabling checkpoint loading entirely, adding CLI arguments, or modifying private methods; see the commit message above).

Testing
The implementation:

- Allows _class_path modification when needed

Example Use Cases

The following cases are illustrated by the sketches after this list:

- Remove training-only parameters
- Change module class in subclass mode
- Disable all checkpoint hyperparameters
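Hedged sketches of the three use cases, assembled from the examples discussed in this thread; the CLI class names and the module path assigned to _class_path are illustrative:

```python
from typing import Any

from lightning.pytorch.cli import LightningCLI


class RemoveTrainingParamsCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        if subcommand != "fit":
            checkpoint_hparams.pop("lr", None)
            checkpoint_hparams.pop("weight_decay", None)
        return checkpoint_hparams


class SwapClassCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # In subclass mode, point _class_path at the module class used for this run
        if subcommand == "predict":
            checkpoint_hparams["_class_path"] = "my_project.modules.InferenceModule"
        return checkpoint_hparams


class NoCheckpointHparamsCLI(LightningCLI):
    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        # Returning an empty dict disables checkpoint hyperparameters entirely
        return {}
```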
Does your PR introduce any breaking changes?
No, this is a purely additive change. The default implementation returns hyperparameters unchanged, preserving existing behavior.
Before submitting
PR review
cc: @mauvilsa @ziw-liu
📚 Documentation preview 📚: https://pytorch-lightning--21408.org.readthedocs.build/en/21408/